//+------------------------------------------------------------------+
//|                                            RandomForest Test.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "https://www.mql5.com/en/users/omegajoctan"
#property version   "1.00"

//#define DEBUG_MODE

#include <forest.mqh>
#include <preprocessing.mqh>

// Random forest model; allocated with `new` in TrainTree() and deleted in OnDeinit()
CRandomForestClassifier *classifier_forest;
//CRandomForestRegressor *regressor_forest;

// Helper used by TrainTree() to split the dataset into train/test parts
CMatrixutils matrix_utils;
// Helper used by TrainTree() to report train/test accuracy
CMetrics metrics;

//--- trading libraries
#include <Trade\Trade.mqh>
#include <Trade\PositionInfo.mqh>

// Order sender; magic number, deviation and filling mode are configured in OnInit()
CTrade m_trade;
// Position inspector used by PosExists()
CPositionInfo m_position;

// Magic number stamped on this EA's orders and used to find its positions
#define MAGICNUMBER 3122023

//How many bars to collect in the past for training purposes

input int   train_bars = 1000;

input group "Decision Tree";

input uint min_sample = 3;
input uint max_depth_ = 3;
// NOTE(review): tree_mode is not referenced anywhere in this file; TrainTree()
// constructs CRandomForestClassifier without it. Confirm whether the forest
// accepts a split criterion and wire it through, or drop this input.
input mode tree_mode = MODE_GINI;

input group "Random Forest";
input uint number_of_trees = 100;
input bool bootstrapping = true;

input group "Rsi";

input int rsi_period = 13;
input ENUM_APPLIED_PRICE rsi_applied =PRICE_CLOSE;

input group "Stochastic";

input int k_period = 5;
input int d_period = 3;
input int slowing = 3;
input ENUM_MA_METHOD stoch_mode = MODE_EMA;
input ENUM_STO_PRICE stoch_applied = STO_CLOSECLOSE;

input group "Trade Params";
// stoploss / takeprofit are distances in points (multiplied by Point() in OnTick());
// slippage is passed to CTrade::SetDeviationInPoints() in OnInit()
input double stoploss = 300;
input double takeprofit = 350;
input uint slippage = 100;

bool train_once = false; //Training is computationally expensive, you might wanna train once during EA runtime

// Indicator handles created in OnInit(); read by TrainTree() and randomforestSignal()
int rsi_handle, stoch_handle;

// Scratch vectors for the three features and the label.
// TrainTree() fills all four; randomforestSignal() overwrites the three feature vectors.
struct data{
   vector stoch_buff, 
          signal_buff, 
          rsi_buff, 
          target;
} data_struct;

// Feature order used by randomforestSignal(): [0] RSI, [1] Stochastic main, [2] Stochastic signal
vector x_vars(3); //Independent/feature variables
// Latest tick, refreshed in OnTick() right before orders are sent
MqlTick ticks;
int prev_bars = 0; //keeping track of the new bar

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//| Validates inputs, creates the RSI and Stochastic handles used as |
//| model features and configures the trade object. Training itself  |
//| happens on the first tick because it is expensive.               |
//| Returns INIT_PARAMETERS_INCORRECT for a non-positive train_bars  |
//| and INIT_FAILED when an indicator handle cannot be created.      |
//+------------------------------------------------------------------+
int OnInit()
  {
   train_once = false;   // force (re)training on the first tick after (re)initialisation

   rsi_handle = INVALID_HANDLE;
   stoch_handle = INVALID_HANDLE;

//--- The training matrix is sized by train_bars, so it must be positive
   if (train_bars <= 0)
     {
      Print(__FUNCTION__, ": train_bars must be positive, got ", train_bars);
      return(INIT_PARAMETERS_INCORRECT);
     }

   rsi_handle = iRSI(Symbol(),PERIOD_CURRENT,rsi_period,rsi_applied);
   stoch_handle = iStochastic(Symbol(), PERIOD_CURRENT, k_period, d_period, slowing, stoch_mode, stoch_applied);

//--- Without valid handles every buffer copy would fail, so refuse to start
   if (rsi_handle == INVALID_HANDLE || stoch_handle == INVALID_HANDLE)
     {
      Print(__FUNCTION__, ": failed to create indicator handles, error = ", GetLastError());
      return(INIT_FAILED);
     }

//--- Trade object configuration

   m_trade.SetExpertMagicNumber(MAGICNUMBER);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetTypeFillingBySymbol(Symbol());
   m_trade.SetMarginMode();

   //TrainTree();

   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//| Frees the forest allocated in TrainTree() and releases the       |
//| indicator handles created in OnInit().                           |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- `delete` only applies to dynamically created objects, and one call is
//    enough; no loop is needed.
   if (CheckPointer(classifier_forest) == POINTER_DYNAMIC)
      delete classifier_forest;
   classifier_forest = NULL;

//--- Release the indicator handles so they are not leaked across re-inits
   if (rsi_handle != INVALID_HANDLE)
      IndicatorRelease(rsi_handle);
   if (stoch_handle != INVALID_HANDLE)
      IndicatorRelease(stoch_handle);

   rsi_handle = INVALID_HANDLE;
   stoch_handle = INVALID_HANDLE;
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//| Trains the forest on the first tick. It keeps retrying on later  |
//| ticks while no model exists. Then, on each new bar, it asks the  |
//| forest for a signal. If this EA has no open position of that type|
//| yet, it opens one minimum-lot position in the signalled direction|
//| (-1 = sell, 1 = buy). Any other signal value means "no trade".   |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- Training is expensive, so it runs once per EA lifetime. It is only
//    marked as done once a model actually exists, so a failed attempt
//    (e.g. history not ready yet) is retried on the next tick.
   if (!train_once)
     {
      TrainTree();
      train_once = (CheckPointer(classifier_forest) != POINTER_INVALID);
      if (!train_once)
         return;
     }

   if (!isnewBar(PERIOD_CURRENT)) // We want to trade on the bar opening
      return;

   int signal = randomforestSignal();
   if (signal != 1 && signal != -1)   // no usable prediction -> do not trade
      return;

//--- Orders must be priced from a fresh tick; skip this bar if none is available
   if (!SymbolInfoTick(Symbol(), ticks))
     {
      Print(__FUNCTION__, ": SymbolInfoTick failed, error = ", GetLastError());
      return;
     }

   double min_lot = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);

   if (signal == -1)
     {
      if (!PosExists(MAGICNUMBER, POSITION_TYPE_SELL)) // If a sell trade doesnt exist
         m_trade.Sell(min_lot, Symbol(), ticks.bid, ticks.bid+stoploss*Point(), ticks.bid - takeprofit*Point());
     }
   else
     {
      if (!PosExists(MAGICNUMBER, POSITION_TYPE_BUY))  // If a buy trade doesnt exist
         m_trade.Buy(min_lot, Symbol(), ticks.ask, ticks.ask-stoploss*Point(), ticks.ask + takeprofit*Point());
     }
  }
//+------------------------------------------------------------------+
//| Builds the training set and fits the random forest.              |
//|                                                                  |
//| Features: RSI, Stochastic main line, Stochastic signal line.     |
//| Label:    1 if the bar closed above its open, -1 otherwise.      |
//|                                                                  |
//| Every series is copied into a timeseries array (index 0 = most   |
//| recent bar), so all columns share one index order. Row i pairs   |
//| the indicator values of the closed bar at shift i+2 with the     |
//| direction of the following bar (shift i+1). The features are     |
//| therefore known before the labelled bar opens, and that bar's    |
//| own close cannot leak into them.                                 |
//|                                                                  |
//| On a data error it prints a message and returns without          |
//| touching classifier_forest.                                      |
//+------------------------------------------------------------------+
void TrainTree()
 {
  if (train_bars <= 0)
    {
      Print(__FUNCTION__, ": train_bars must be positive, got ", train_bars);
      return;
    }

//--- Collecting indicator buffers (shifts 2 .. train_bars+1)

  double rsi[], stoch_main[], stoch_signal[];
  ArraySetAsSeries(rsi, true);
  ArraySetAsSeries(stoch_main, true);
  ArraySetAsSeries(stoch_signal, true);

  if (CopyBuffer(rsi_handle, 0, 2, train_bars, rsi) != train_bars ||
      CopyBuffer(stoch_handle, 0, 2, train_bars, stoch_main) != train_bars ||
      CopyBuffer(stoch_handle, 1, 2, train_bars, stoch_signal) != train_bars)
    {
      Print(__FUNCTION__, ": failed to copy indicator buffers, error = ", GetLastError());
      return;
    }

//--- Preparing the target variable (shifts 1 .. train_bars)

  MqlRates rates[];
  ArraySetAsSeries(rates, true);

  if (CopyRates(Symbol(), PERIOD_CURRENT, 1, train_bars, rates) != train_bars)
    {
      Print(__FUNCTION__, ": failed to copy rates, error = ", GetLastError());
      return;
    }

//--- Fill features and labels with one shared index (i -> same bar pair)

  data_struct.rsi_buff.Resize(train_bars);
  data_struct.stoch_buff.Resize(train_bars);
  data_struct.signal_buff.Resize(train_bars);
  data_struct.target.Resize(train_bars);

  for (int i=0; i<train_bars; i++)
    {
      data_struct.rsi_buff[i]    = rsi[i];
      data_struct.stoch_buff[i]  = stoch_main[i];
      data_struct.signal_buff[i] = stoch_signal[i];
      data_struct.target[i]      = (rates[i].close > rates[i].open) ? 1 : -1;
    }

  matrix dataset(train_bars, 4);
  dataset.Col(data_struct.rsi_buff, 0);
  dataset.Col(data_struct.stoch_buff, 1);
  dataset.Col(data_struct.signal_buff, 2);
  dataset.Col(data_struct.target, 3);

  matrix train_x, test_x;
  vector train_y, test_y;

  matrix_utils.TrainTestSplitMatrices(dataset, train_x, train_y, test_x, test_y, 0.8, 42); //split the data into training and testing samples

//--- Random Forest Classifier
//    Free any previously trained forest so repeated calls do not leak it.
//    NOTE(review): the tree_mode input is not passed here; confirm whether
//    CRandomForestClassifier supports a split criterion.

  if (CheckPointer(classifier_forest) == POINTER_DYNAMIC)
     delete classifier_forest;

  classifier_forest = new CRandomForestClassifier(number_of_trees, min_sample, max_depth_);

  classifier_forest.fit(train_x, train_y, bootstrapping);

  vector preds = classifier_forest.predict(train_x); //making the predictions on the training data

  Print("Forest Train Acc = ",metrics.accuracy_score(train_y, preds)); //Measuring the accuracy

//---

  preds = classifier_forest.predict(test_x); //making the predictions on the testing data

  Print("Forest Test Acc = ",metrics.accuracy_score(test_y, preds)); //Measuring the accuracy
 }
//+------------------------------------------------------------------+
//| Function to provide signals for live trading using random forest |
//| Feeds the indicator values of the last closed bar (shift 1) to   |
//| the forest. Those values are final, unlike the provisional ones  |
//| of the bar that has just opened. Returns the predicted class.    |
//| Training labels are 1 = bullish bar and -1 = bearish bar.        |
//| Returns 0 when no prediction can be made (no trained model or    |
//| indicator data unavailable). Callers should treat 0 as "no trade"|
//+------------------------------------------------------------------+
int randomforestSignal()
 {
   if (CheckPointer(classifier_forest) == POINTER_INVALID)
      return 0;

//--- Copy the last closed bar's indicator values only

   if (!data_struct.rsi_buff.CopyIndicatorBuffer(rsi_handle, 0, 1, 1) ||
       !data_struct.stoch_buff.CopyIndicatorBuffer(stoch_handle, 0, 1, 1) ||
       !data_struct.signal_buff.CopyIndicatorBuffer(stoch_handle, 1, 1, 1))
     {
      Print(__FUNCTION__, ": failed to copy indicator buffers, error = ", GetLastError());
      return 0;
     }

   x_vars[0] = data_struct.rsi_buff[0];
   x_vars[1] = data_struct.stoch_buff[0];
   x_vars[2] = data_struct.signal_buff[0];

   return int(classifier_forest.predict(x_vars));
 }
//+------------------------------------------------------------------+
//| Reports whether an open position of the requested type exists on |
//| the chart symbol with the requested magic number. Positions are  |
//| scanned from the last index down to the first.                   |
//+------------------------------------------------------------------+
bool PosExists(int magic, ENUM_POSITION_TYPE type)
 {
   int index = PositionsTotal();

   while (index > 0)
     {
      index--;

      if (!m_position.SelectByIndex(index))
         continue;                           // position vanished or cannot be read

      if (m_position.Symbol() != Symbol())
         continue;                           // belongs to another symbol

      if (m_position.Magic() == magic && m_position.PositionType() == type)
         return true;
     }

   return false;
 }
//+------------------------------------------------------------------+
//| Returns true on the first call after a new bar has opened on TF. |
//| The very first call only records the current bar and returns     |
//| false, as before. It compares the open time of the current bar   |
//| instead of the bar count. Bars() may stop changing once the      |
//| terminal's "Max bars in chart" limit is reached, and it jumps    |
//| when older history is downloaded, so a count can both miss new   |
//| bars and report false ones.                                      |
//+------------------------------------------------------------------+
bool isnewBar(ENUM_TIMEFRAMES TF)
 {
   static datetime last_bar_time = 0;

   datetime bar_time = iTime(Symbol(), TF, 0);
   if (bar_time == 0)            // history not available yet
      return false;

   if (last_bar_time == 0)       // first call: just remember the current bar
     {
      last_bar_time = bar_time;
      return false;
     }

   if (bar_time == last_bar_time)
      return false;

   last_bar_time = bar_time;
   return true;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+

